# 此代码用于实现对Astar轨迹的剪枝优化
# 对第一轮生成的航点进行检测，去除不必要的航点，得到最终航点，并返回实际位置
# import plot_module as plot
# import world_set as world
# import matplotlib.pyplot as plt
import numpy as np
# from PIL import Image

map_grid_ = np.load('npy/resized_map.npy')
min_straight = 5    # 最小的直线段长度


# 输入两个航点(前后）, 输出航向
def get_direction(point1, point2):
    """输入路径上相邻的格点，返回前进方向"""
    dx = point2[0] - point1[0]
    dy = point2[1] - point1[1]
    dir_index = int(3*dx + 4 + dy)

    dir_list = ['southwest', 'west', 'northwest', 'south', 'unmoved', 'north', 'southeast', 'east', 'northeast']
    direction = dir_list[dir_index]

    return direction


def breakpoint_detection(path_points):
    """输入航迹，返回航迹的折点, 包括起点和终点, 航迹段落的长度，方向

    :param path_points: 航迹
    :return: 折点坐标包含起点终点[[x0,y0], [x1,y1]]， 每段长度=[1,2,3],每段方向[''north', 'northeast']
    """
    # 路径倒叙，实现从起点到终点
    path = path_points[::-1]
    # 主要部分
    break_points = [path[0]]
    last_dir = get_direction(path[0], path[1])
    direction_list = [last_dir]
    for i in range(2, len(path)):
        this_dir = get_direction(path[i-1], path[i])
        if this_dir != last_dir:
            break_points.append(path[i-1])
            last_dir = this_dir
            direction_list.append(this_dir)

    break_points.append(path[-1])

    # 计算线段长度
    length_line = []
    for i in range(1, len(break_points)):
        if direction_list[i-1] in ['southwest', 'southeast', 'northwest', 'northeast']:
            length = abs(break_points[i][0] - break_points[i-1][0]) * 1.414
        else:
            length = max(abs(break_points[i][0] - break_points[i-1][0]),
                         abs(break_points[i][1] - break_points[i-1][1])) * 1.0
        length_line.append(length)

    break_points = np.array(break_points)

    # 返回值是折点，包含起点和终点， 长度为k，  线段长度和线段方向的列表 的长度为k-1
    return break_points, length_line, direction_list


def apf_segment(lines):
    """根据输入的断点和断点两端方向，确定人工势场法的作用开始和结束点

    :param lines: 断点列表, 包含折线段的n个端点（包含起点终点）、n-1个线段的长度和方向
    :return: 人工势场作用域的开始和结束点
    """
    points = lines[0]
    distance = lines[1]
    direction = lines[2]

    # 判断采用人工势场的起点和终点
    start_points = []
    end_points = []

    tmp = enumerate(distance)
    for i, dis in tmp:
        if dis < min_straight:
            start_points.append([points[i], i, direction[i]])

            # 终点的确定
            while distance[i+1] < min_straight:
                next(tmp)
                i += 1
                if i+1 == len(distance):
                    break

            end_points.append([points[i + 1], i+1])

    # 对每一个折点，起点前移，终点后移
    dir_set = {'southwest': [-3, -3], 'west': [-4, 0], 'northwest': [-3, 3], 'south': [0, -4],
               'unmoved': [0, 0], 'north': [4, 0], 'southeast': [3, -3], 'east': [4, 0], 'northeast': [3, 3]}
    for point, i, nonsense in start_points:
        dx, dy = dir_set[direction[i-1]]
        point[0] -= dx
        point[1] -= dy

    for point, i in end_points:
        if point == points[-1]:
            break
        dx, dy = dir_set[direction[i]]
        point[0] += dx
        point[1] += dy

    return start_points, end_points


def delete_dot(break_dot_, map_grid=map_grid_):
    """检测一次路径的折点是否可以省略，输入拐点，返回拐点（包括开始结束点）in=out=[[x1,y1],[x,y]]"""
    class Dot:
        def __init__(self, x, y):
            self.x = x
            self.y = y
            self.delete = False

    break_dot = []
    for dot in break_dot_:
        break_dot.append(Dot(dot[0], dot[1]))

    for end_index in range(2, len(break_dot)):
        start_dots = break_dot[:end_index-1]  # 从第一个点到监测点前两个的点

        for start_index, start_dot in enumerate(start_dots):
            if start_dot.delete:
                continue

            crash_flag = is_crash([start_dot.x, start_dot.y],
                                  [break_dot[end_index].x, break_dot[end_index].y], map_grid)
            if crash_flag:
                continue   # 起始点变为下一点
            else:
                for i in range(start_index+1, end_index):
                    break_dot[i].delete = True
                break

    final_dots = [(dot.x, dot.y) for dot in break_dot if not dot.delete]
    return final_dots


def is_crash(dot1, dot2, map_grid=map_grid_):
    """判断两点连线和全局格点地图中障碍物点是否相交"""
    # 地图由全局变量map_grid 给出
    crash_flag = False
    pass_grids = []

    def float_range(a, b, c):
        return range(int(a), int(b), int(c))

    if dot2[0] == dot1[0]:
        # range函数不会自己判断符号
        for i in range(dot1[1]+1, dot2[1], np.sign(dot2[1]-dot1[1])):
            pass_grids.append((dot1[0], i))
            if map_grid[dot1[0]][i] == 0:
                return True
        return False

    if dot2[1] == dot1[1]:
        for i in range(dot1[0]+1, dot2[0], np.sign(dot2[0] - dot1[0])):
            pass_grids.append((i, dot1[1]))
            if map_grid[i][dot1[1]] == 0:
                return True
        return False

    else:
        # 两点连线是斜线，穿过多个格子
        # 表达式 y-0.5 = k(x-0.5)  0.5是格点中心
        k = (dot2[1] - dot1[1]) / (dot2[0] - dot1[0])

        def get_y(x_value):
            y_value = k*(x_value - dot1[0]) + dot1[1]
            return y_value

        y_list = [np.floor(dot1[1])]
        x_sign = int(np.sign(dot2[0] - dot1[0]) )
        y_sign = int(np.sign(dot2[1] - dot1[1]) )
        for x in float_range(dot1[0]+x_sign, dot2[0], x_sign):
            y_list.append(np.floor(get_y(x)))

            y_start = y_list[-2]
            y_end = y_list[-1]
            for y in float_range(y_start, y_end+y_sign, y_sign):
                pass_grids.append((x-x_sign, y))
                if map_grid[x-x_sign][y] == 0:
                    return True
        # 考虑x = dot2[0]
        y_start = np.floor(y_list[-1])
        y_end = np.floor(dot2[1])
        for y in float_range(y_start, y_end+y_sign, y_sign):
            pass_grids.append((dot2[0], y))
            if map_grid[dot2[0]][y] == 0:
                return True

        return crash_flag


def actual_loc(path_dots, scale):
    """将格点路径转化为真实路径，输入路径点和比例尺，返回真实世界航点坐标（以格点中心计）"""
    actual_dot = []
    for dot in path_dots:
        x = int((dot[0]+0.5) * scale)
        y = int((dot[1]+0.5) * scale)
        actual_dot.append([x, y])

    return actual_dot


def use_method():
    # 使用范例
    path = np.load('npy/raw_path.npy')
    breakpoints, lengths, dirs = breakpoint_detection(path)
    final_dots = delete_dot(breakpoints)
    print("path: ", path.transpose())
    print("breakpoints: ", breakpoints.transpose())
    print("final_dots: ", final_dots)


    # print('path: ', path)
    # print('breakpoints: ', breakpoints)
    # print('lengths: ', lengths)
    # print('dirs: ', dirs)

    # next_dots = detect_second(breakpoints)
    # actual_locate = actual_loc(next_dots, scale=10)
    # np.save('actual_dots_loc.npy', actual_locate)
    #
    # # 绘图
    # image = Image.open('binary.png')
    # # 将图片转换为RGB模式
    # image = image.convert('RGB')
    # for dot in path:
    #     x = dot[0]
    #     y = dot[1]
    #     image.putpixel((y, x), (255, 130, 71))
    #
    # for dot in breakpoints:
    #     x = dot[0]
    #     y = dot[1]
    #     image.putpixel((y, x), (255, 0, 0))
    #
    # for dot in next_dots:
    #     x = dot[0]
    #     y = dot[1]
    #     image.putpixel((y, x), (0, 255, 0))
    #
    # image.show()
    #
    # print('breakpoints', breakpoints)
    # print('next_dots=', next_dots)


def test(num):
    if num == 1:
        use_method()
    else:
        a = is_crash([16, 59], [20, 20])
        print(a)


if __name__ == '__main__':
    map = np.load('npy/resized_map.npy')
    is_crash([54,60],[50,70],map)
